The purpose of this project is to develop a model that predicts the edibility of mushrooms from their features: cap (shape, surface, color), bruises, odor, gill (attachment, spacing, size, color), stalk (shape, root, surface above ring, surface below ring, color above ring, color below ring), veil (type, color), ring (number, type), spore print color, population, and habitat. We will use a data set from Kaggle and implement multiple techniques to find the most accurate model for this binary classification problem.
I myself am a huge mushroom hater: I can’t stand the taste of mushrooms and have never had any interest in them. For the first 20 years of my life, mushrooms played no part in it. It was only when a friend of mine in China experienced mushroom poisoning that I realized how interesting mushrooms can be. Last summer my friend traveled to Yunnan province, which is famous for its delicious mushrooms, and tried some to experience the local culture. After eating them, she had a very strange experience: the world became colorful, and she felt that there were many little people around her. At the local hospital she learned that it was mushroom poisoning, which is very common in Yunnan province. I found this story fascinating and gradually became very interested in the toxicity of mushrooms. So I chose this data set on the edibility of mushrooms to predict whether a mushroom is poisonous based on its characteristics. I hope this can help travelers to Yunnan, China, avoid eating poisonous mushrooms.
This data was taken from the Kaggle data set “Mushroom Classification,” a data set of mushroom descriptions uploaded by the user UCI Machine Learning.
Now that we are getting familiar with our data, let’s dive into our plans for the model in this project. First, we will clean our data, check whether there are any missing observations, and remove any predictor variables that are unnecessary. The remaining variables will be used to predict a binary response variable named class, which records whether a mushroom is edible or poisonous. We will then perform a training/test split on our data, make a recipe, and set folds for the 10-fold cross validation we will implement. Once the setup is finished, logistic regression, k-nearest neighbors, decision tree, and random forest models will all be fit to the training data. Once we have the results from each model, we will select the best performers and evaluate them on our test data set to find out how effective they really are at predicting the edibility of mushrooms. Let’s get started!
We will first load the data and clean the variable names.
library(tidyverse)
library(dplyr)
library(tidymodels)
library(readr)
library(kknn)
library(janitor)
library(ISLR)
library(discrim)
library(poissonreg)
library(glmnet)
library(corrr)
library(corrplot)
library(tune)
library(xgboost)
library(vip)
library(ranger)
library(ggplot2)
library(forcats)
tidymodels_prefer()
# loading the data
mushrooms <- read.csv("/Users/zoezhou/Desktop/Pstat131 Final Project/data/mushrooms.csv")
# cleaning predictor names
mushrooms <- clean_names(mushrooms)
# Calling head() to see the first few rows
head(mushrooms)
## class cap_shape cap_surface cap_color bruises odor gill_attachment
## 1 p x s n t p f
## 2 e x s y t a f
## 3 e b s w t l f
## 4 p x y w t p f
## 5 e x s g f n f
## 6 e x y y t a f
## gill_spacing gill_size gill_color stalk_shape stalk_root
## 1 c n k e e
## 2 c b k e c
## 3 c b n e c
## 4 c n n e e
## 5 w b k t e
## 6 c b n e c
## stalk_surface_above_ring stalk_surface_below_ring stalk_color_above_ring
## 1 s s w
## 2 s s w
## 3 s s w
## 4 s s w
## 5 s s w
## 6 s s w
## stalk_color_below_ring veil_type veil_color ring_number ring_type
## 1 w p w o p
## 2 w p w o p
## 3 w p w o p
## 4 w p w o p
## 5 w p w o e
## 6 w p w o p
## spore_print_color population habitat
## 1 k s u
## 2 n n g
## 3 n n m
## 4 k s u
## 5 n a g
## 6 k n g
Now, let’s check if there is any missing data in this data set!
#find the missing data
summary(mushrooms)
## class cap_shape cap_surface cap_color
## Length:8124 Length:8124 Length:8124 Length:8124
## Class :character Class :character Class :character Class :character
## Mode :character Mode :character Mode :character Mode :character
## bruises odor gill_attachment gill_spacing
## Length:8124 Length:8124 Length:8124 Length:8124
## Class :character Class :character Class :character Class :character
## Mode :character Mode :character Mode :character Mode :character
## gill_size gill_color stalk_shape stalk_root
## Length:8124 Length:8124 Length:8124 Length:8124
## Class :character Class :character Class :character Class :character
## Mode :character Mode :character Mode :character Mode :character
## stalk_surface_above_ring stalk_surface_below_ring stalk_color_above_ring
## Length:8124 Length:8124 Length:8124
## Class :character Class :character Class :character
## Mode :character Mode :character Mode :character
## stalk_color_below_ring veil_type veil_color
## Length:8124 Length:8124 Length:8124
## Class :character Class :character Class :character
## Mode :character Mode :character Mode :character
## ring_number ring_type spore_print_color population
## Length:8124 Length:8124 Length:8124 Length:8124
## Class :character Class :character Class :character Class :character
## Mode :character Mode :character Mode :character Mode :character
## habitat
## Length:8124
## Class :character
## Mode :character
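As the summary shows, all 8124 rows are present for every column, so there are no obviously incomplete records and we don’t need to delete any variables. Note, however, that summary() on character columns reports only length and class, so it cannot reveal missing values by itself. In this data set, missing values in stalk_root are encoded as “?”, which we relabel as “missing” below. We still have to convert the categorical variables to factors and give the single-letter codes readable names before training.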
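Because summary() only reports the length and class of character columns, an explicit count is a more direct check. The sketch below uses a small made-up data frame (toy and its values are hypothetical, not the real file) to count both true NA values and “?” placeholders per column.

```r
# Toy stand-in for the raw data; the values are hypothetical.
toy <- data.frame(
  class = c("p", "e", "e", "p"),
  stalk_root = c("e", "?", "c", "?"),
  odor = c("p", "a", NA, "n"),
  stringsAsFactors = FALSE
)

# Count true NA values per column.
na_counts <- colSums(is.na(toy))

# Count "?" placeholders per column (na.rm so NA cells do not propagate).
placeholder_counts <- colSums(toy == "?", na.rm = TRUE)

print(na_counts)
print(placeholder_counts)
```

Running the same two colSums() calls on mushrooms right after loading would show which columns, if any, need attention.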
#convert categorical variables to factors
mushrooms<- data.frame(lapply(mushrooms, factor))
sapply(mushrooms, class)
## class cap_shape cap_surface
## "factor" "factor" "factor"
## cap_color bruises odor
## "factor" "factor" "factor"
## gill_attachment gill_spacing gill_size
## "factor" "factor" "factor"
## gill_color stalk_shape stalk_root
## "factor" "factor" "factor"
## stalk_surface_above_ring stalk_surface_below_ring stalk_color_above_ring
## "factor" "factor" "factor"
## stalk_color_below_ring veil_type veil_color
## "factor" "factor" "factor"
## ring_number ring_type spore_print_color
## "factor" "factor" "factor"
## population habitat
## "factor" "factor"
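One thing to keep in mind for the relabelling step that follows: factor() sorts the original codes alphabetically, and levels(x) <- c(...) assigns the new labels by position in that sorted order. A named lookup avoids having to get the order right by hand. This is a minimal sketch on a toy vector; surface_labels is a hypothetical helper, and the code-to-label pairs follow the data set’s documented cap-surface codes.

```r
# Single-letter codes as they appear in the raw data (toy vector).
codes <- c("s", "y", "f", "g", "s")
x <- factor(codes)

# Levels are sorted alphabetically, not in order of appearance.
print(levels(x))

# Named lookup: code -> label (hypothetical helper object).
surface_labels <- c(f = "fibrous", g = "grooves", s = "smooth", y = "scaly")

# Relabel by name, so the order of the lookup vector does not matter.
levels(x) <- surface_labels[levels(x)]
print(as.character(x))
```

The same pattern works for any of the other variables, with one lookup vector per column.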
# We relabel the levels of each variable. factor() sorts the original
# single-letter codes alphabetically, so the labels must follow that order.
levels(mushrooms$class) <- c("edible", "poisonous")
levels(mushrooms$cap_shape) <- c("bell", "conical", "flat", "knobbed", "sunken", "convex")
levels(mushrooms$cap_color) <- c("buff", "cinnamon", "red", "gray", "brown", "pink",
"green", "purple", "white", "yellow")
# codes f, g, s, y: "s" is smooth and "y" is scaly
levels(mushrooms$cap_surface) <- c("fibrous", "grooves", "smooth", "scaly")
levels(mushrooms$bruises) <- c("no", "yes")
levels(mushrooms$odor) <- c("almond", "creosote", "foul", "anise", "musty", "none", "pungent", "spicy", "fishy")
levels(mushrooms$gill_attachment) <- c("attached", "free")
levels(mushrooms$gill_spacing) <- c("close", "crowded")
levels(mushrooms$gill_size) <- c("broad", "narrow")
levels(mushrooms$gill_color) <- c("buff", "red", "gray", "chocolate", "black", "brown", "orange",
"pink", "green", "purple", "white", "yellow")
levels(mushrooms$stalk_shape) <- c("enlarging", "tapering")
levels(mushrooms$stalk_root) <- c("missing", "bulbous", "club", "equal", "rooted")
levels(mushrooms$stalk_surface_above_ring) <- c("fibrous", "silky", "smooth", "scaly")
levels(mushrooms$stalk_surface_below_ring) <- c("fibrous", "silky", "smooth", "scaly")
# stalk colors use nine codes: b, c, e, g, n, o, p, w, y
levels(mushrooms$stalk_color_above_ring) <- c("buff", "cinnamon", "red", "gray", "brown", "orange",
"pink", "white", "yellow")
levels(mushrooms$stalk_color_below_ring) <- c("buff", "cinnamon", "red", "gray", "brown", "orange",
"pink", "white", "yellow")
levels(mushrooms$veil_type) <- c("partial", "universal")
levels(mushrooms$veil_color) <- c("brown", "orange", "white", "yellow")
levels(mushrooms$ring_number) <- c("none", "one", "two")
levels(mushrooms$ring_type) <- c("evanescent", "flaring", "large", "none", "pendant")
levels(mushrooms$spore_print_color) <- c("buff", "chocolate", "black", "brown", "orange",
"green", "purple", "white", "yellow")
levels(mushrooms$population) <- c("abundant", "clustered", "numerous", "scattered", "several", "solitary")
levels(mushrooms$habitat) <- c("woods", "grasses", "leaves", "meadows", "paths", "urban", "waste")
Let’s take a quick look at our data set by displaying the first 6 rows.
head(mushrooms)
## class cap_shape cap_surface cap_color bruises odor gill_attachment
## 1 poisonous convex smooth brown yes pungent free
## 2 edible convex smooth yellow yes almond free
## 3 edible bell smooth white yes anise free
## 4 poisonous convex scaly white yes pungent free
## 5 edible convex smooth gray no none free
## 6 edible convex scaly yellow yes almond free
## gill_spacing gill_size gill_color stalk_shape stalk_root
## 1 close narrow black enlarging equal
## 2 close broad black enlarging club
## 3 close broad brown enlarging club
## 4 close narrow brown enlarging equal
## 5 crowded broad black tapering equal
## 6 close broad brown enlarging club
## stalk_surface_above_ring stalk_surface_below_ring stalk_color_above_ring
## 1 smooth smooth white
## 2 smooth smooth white
## 3 smooth smooth white
## 4 smooth smooth white
## 5 smooth smooth white
## 6 smooth smooth white
## stalk_color_below_ring veil_type veil_color ring_number ring_type
## 1 white partial white one pendant
## 2 white partial white one pendant
## 3 white partial white one pendant
## 4 white partial white one pendant
## 5 white partial white one evanescent
## 6 white partial white one pendant
## spore_print_color population habitat
## 1 black scattered urban
## 2 brown numerous grasses
## 3 brown numerous meadows
## 4 black scattered urban
## 5 brown abundant grasses
## 6 black numerous grasses
Now, we will explore the relationships between selected variables and the outcome, as well as with each other!
First, let’s explore the distribution of our outcome, class. The bar plot below shows how many mushrooms fall into each class. As a reminder, “edible” means the mushroom is safe to eat and “poisonous” means it is not.
mushrooms %>%
ggplot(aes(x = class)) +
geom_bar() +
labs(x = "Mushroom Class", y = "# of Mushrooms", title = "Distribution of the Number of Mushrooms per class")
As we can see from the distribution plot, slightly more than 4000 mushrooms are edible and slightly fewer than 4000 are poisonous. Edible mushrooms are somewhat more common than poisonous ones, but the two classes are roughly balanced.
Next, we look at how a mushroom’s cap_surface and cap_color relate to whether it is poisonous. The following plot shows every combination of cap_surface and cap_color, with green points for edible mushrooms and red points for poisonous ones.
ggplot(mushrooms, aes(x = cap_surface, y = cap_color, col = class)) +
geom_jitter(alpha = 0.5) +
scale_color_manual(breaks = c("edible", "poisonous"),
values = c("green", "red"))
As the plot shows, if we want to stay safe, fibrous caps are the better bet, unless they are yellow or gray. Smooth caps that are yellow also look safe. Stay especially away from scaly caps, unless they are purple or green, and from grooved caps that are white.
Next, we look at how a mushroom’s cap_shape and cap_color relate to whether it is poisonous. The following plot shows every combination of cap_shape and cap_color, again with green points for edible mushrooms and red points for poisonous ones.
ggplot(mushrooms, aes(x = cap_shape, y = cap_color, col = class)) +
geom_jitter(alpha = 0.5) +
scale_color_manual(breaks = c("edible", "poisonous"),
values = c("green", "red"))
As the plot above shows, no cap shape is safe on its own. Bell-shaped mushrooms are the closest to an exception, as long as their cap_color is not buff or pink.
Next, we look at how a mushroom’s gill_color and cap_color relate to whether it is poisonous. The following plot shows every combination of gill_color and cap_color, again with green points for edible mushrooms and red points for poisonous ones.
ggplot(mushrooms, aes(x = gill_color, y = cap_color, col = class)) +
geom_jitter(alpha = 0.5) +
scale_color_manual(breaks = c("edible", "poisonous"),
values = c("green", "red"))
Wow! It seems we have more choice when it comes to gill_color. If we want to stay safe, the better bets are mushrooms with red, black, brown, pink, or purple gills, unless their cap_color is brown, white, pink, or gray. Also stay away from mushrooms with buff gill_color and red or brown cap_color.
Last, we are going to look at the relationship between odor and the edibility of mushrooms.
ggplot(mushrooms, aes(x = class, y = odor, col = class)) +
geom_jitter(alpha = 0.5) +
scale_color_manual(breaks = c("edible", "poisonous"),
values = c("green", "red"))
Odor is definitely quite an informative predictor. Basically, if a mushroom smells fishy, spicy, pungent, foul, or of creosote, just stay away. If it smells like anise or almond, you can go ahead. If it has no smell, it is more likely to be edible than not. But an odorless mushroom can still be poisonous, so don’t get careless!
Now that we have a better idea of how some of the most important variables relate to whether a mushroom is poisonous, it’s time to start fitting models to our data to see if we really can predict the class of a mushroom from the predictors we have. However, we first have to set up our data by splitting it, creating the recipe, and creating folds for k-fold cross validation.
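The same relationship can be read off numerically from a contingency table. The sketch below uses a handful of made-up rows (toy is hypothetical and far smaller than the real data) to show the counts and the within-odor class proportions.

```r
# Hypothetical miniature of the odor and class columns.
toy <- data.frame(
  odor = c("almond", "foul", "none", "none", "foul", "anise", "none"),
  class = c("edible", "poisonous", "edible", "poisonous",
            "poisonous", "edible", "edible")
)

# Counts of each odor/class combination.
counts <- table(toy$odor, toy$class)

# Row-wise proportions: the share of each class within an odor.
props <- prop.table(counts, margin = 1)

print(counts)
print(round(props, 2))
```

On the real data, table(mushrooms$odor, mushrooms$class) would give the counts behind the jitter plot.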
As we approach the building of our models, we first need to split our data into separate data sets. One will be used for the training of our models, and one will be the testing set. Our first step is to set our seed so that our random split can be reproduced every time we train our models. Next, we will perform a training / testing split on our data, and stratify on our response variable, class.
set.seed(0000)
# Splitting the data (70/30 split, stratified on class)
mushrooms_split <- initial_split(mushrooms, prop = 0.7, strata = class)
mushrooms_train <- training(mushrooms_split)
mushrooms_test <- testing(mushrooms_split)
Dimension of our training data:
dim(mushrooms_train)
## [1] 5686 23
Dimension of our testing data:
dim(mushrooms_test)
## [1] 2438 23
Check if the data has been split correctly.
nrow(mushrooms_train)/nrow(mushrooms)
## [1] 0.6999015
The training set has about 70% of the data and the testing set has about 30% of the data. So, the data was split correctly between the training and testing sets.
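To see what stratifying on class buys us, here is a base R sketch of a stratified 70/30 split on a made-up data frame. It is an illustration of the idea only, not how initial_split() is implemented, and toy is hypothetical.

```r
set.seed(1)

# Hypothetical data: 60 edible and 40 poisonous rows.
toy <- data.frame(
  id = 1:100,
  class = rep(c("edible", "poisonous"), times = c(60, 40))
)

# Sample 70% of the row ids separately within each class.
train_idx <- unlist(lapply(split(toy$id, toy$class), function(ids) {
  sample(ids, size = round(0.7 * length(ids)))
}))

train <- toy[toy$id %in% train_idx, ]
test <- toy[!toy$id %in% train_idx, ]

# Both pieces keep the original 60/40 class balance.
print(prop.table(table(train$class)))
print(prop.table(table(test$class)))
```

Because sampling happens within each class, neither set can end up with a very different share of poisonous mushrooms than the full data.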
We will now bring together our predictors and our response variable to build the recipe that we will use for all the models. We will use every variable other than class as a predictor, because they are all potentially useful for prediction. Since the predictors are all factors, we first pool levels that occur in fewer than 5% of the training rows into an “other” level with step_other(), and then convert all the nominal (non-numeric) predictors to dummy variables. The outcome is left as a factor. Below, you can see the complete recipe.
# Building our recipe
mushrooms_recipe <-
recipe(class ~., data = mushrooms_train) %>%
step_other(all_predictors(), threshold = 0.05) %>%
step_dummy(all_nominal_predictors())
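As a rough picture of what dummy encoding produces, the base R function model.matrix() does a similar job for a single factor. This is only an analogy under the default treatment contrasts: step_dummy() names its columns differently, and toy is a hypothetical data frame.

```r
# Hypothetical factor with three levels.
toy <- data.frame(odor = factor(c("almond", "foul", "none", "foul")))

# One indicator column per level except the first (the reference level).
# The intercept column is dropped to leave only the dummies.
dummies <- model.matrix(~ odor, data = toy)[, -1, drop = FALSE]

print(dummies)
```

A factor with k levels becomes k - 1 indicator columns, which is why a data set of 22 factor predictors turns into a much wider numeric table.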
We’ll use 10 folds for cross validation and stratify the folds on our response variable, class.
mushrooms_folds <- vfold_cv(mushrooms_train, v = 10, strata = class)
It is now time to build our models! Since the models take a very long time to run, the results from each model have been saved to avoid rerunning them every time. I have chosen the area under the receiver operating characteristic curve (ROC AUC) as my metric because it measures how well a model can distinguish between two classes, and it is one of the most commonly used metrics for classification problems. The better the classifier, the larger the area under the ROC curve: 1 means perfect separation and 0.5 is no better than random guessing. I have fit 4 models to the mushrooms data; however, we will only conduct further analysis on the 2 best-performing models. Let’s get to building our models!
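For intuition, ROC AUC equals the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative one. The sketch below computes it by hand with the rank-based formula; truth and score are made-up values, and this is not how yardstick computes the metric internally.

```r
# Hypothetical true labels (1 = positive) and predicted scores.
truth <- c(1, 1, 1, 0, 0, 0)
score <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.1)

n_pos <- sum(truth == 1)
n_neg <- sum(truth == 0)

# Rank all scores, then apply the Mann-Whitney relation:
# AUC = (sum of positive ranks - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
r <- rank(score)
auc <- (sum(r[truth == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

print(auc)
```

Here 8 of the 9 positive/negative pairs are ordered correctly (only the positive scored 0.4 falls below the negative scored 0.5), so the AUC is 8/9.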
Each of the models had a very similar process. I will detail it below and include the code for each of the models under that step (the code will not be evaluated here to save time).
For each of the models, we must conduct these steps to fit them:
Step 1: Set up the model by specifying the model you wish to fit, the parameters you want to tune, the engine the model comes from, and the mode (regression or classification) if necessary.
#KNN model
knn_model <- nearest_neighbor(neighbors = tune()) %>%
set_mode("classification") %>%
set_engine("kknn")
#Logistic Regression
log_reg <- logistic_reg() %>%
set_engine("glm") %>%
set_mode("classification")
#Random Forest
rf_class_spec <- rand_forest(mtry = tune(),
trees = tune(),
min_n = tune()) %>%
set_engine("ranger") %>%
set_mode("classification")
# decision tree
mushrooms_tree_spec<-decision_tree() %>%
set_mode("classification") %>%
set_engine("rpart")
Step 2: Set up the workflow for the model and add the model and the recipe.
# KNN model
knn_workflow <- workflow() %>%
add_model(knn_model) %>%
add_recipe(mushrooms_recipe)
# Logistic Regression
mushroomslog_wkflow <- workflow() %>%
add_model(log_reg) %>%
add_recipe(mushrooms_recipe)
# Random Forest
rf_class_wf <- workflow() %>%
add_model(rf_class_spec) %>%
add_recipe(mushrooms_recipe)
# Decision Tree
mushrooms_tree_wf <- workflow() %>%
add_recipe(mushrooms_recipe) %>%
add_model(mushrooms_tree_spec %>%
set_args(cost_complexity = tune()))
Step 3: Create a tuning grid to specify the ranges of the parameters you wish to tune as well as how many levels of each.
# KNN Model
knn_grid <- grid_regular(neighbors(range = c(1,10)), levels = 10)
# Logistic Regression (no tuning parameters, so no grid is needed; the model is fit directly in Step 4)
# Random Forest
rf_grid <- grid_regular(mtry(range = c(1, 10)),
trees(range = c(200, 1000)),
min_n(range = c(10, 20)),
levels = 5)
# Decision Tree
dt_grid <- grid_regular(cost_complexity(range = c(-3, -1)), levels = 3)
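A regular grid is every combination of the chosen levels, so its size grows quickly. The base R sketch below mimics the random forest grid with expand.grid() to show where the 250 metric rows in the tuning results come from. The mtry and trees values match those visible in the results, while the min_n spacing is an assumption. It also shows that the cost_complexity range is given on the log10 scale.

```r
# Stand-in for the random forest grid: 5 levels for each of 3 parameters.
rf_grid_sketch <- expand.grid(
  mtry = c(1, 3, 5, 7, 10),
  trees = seq(200, 1000, length.out = 5),
  min_n = c(10, 12, 15, 17, 20)  # assumed spacing between 10 and 20
)

# 5 * 5 * 5 = 125 candidate models.
n_models <- nrow(rf_grid_sketch)
print(n_models)

# Two metrics (accuracy and roc_auc) per candidate gives 250 rows per fold.
print(n_models * 2)

# cost_complexity(range = c(-3, -1)) is on the log10 scale.
cp_values <- 10^seq(-3, -1, length.out = 3)
print(cp_values)
```

With 10 folds, the random forest grid therefore means 1250 model fits, which explains the long run time.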
Step 4: Tune the model and specify the workflow, k-fold cross validation folds, and the tuning grid for our chosen parameters to tune.
#knn model
knn_tune <- tune_grid(
knn_workflow,
resamples = mushrooms_folds,
grid = knn_grid)
knn_tune
## # Tuning results
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [5116/570]> Fold01 <tibble [20 × 5]> <tibble [0 × 3]>
## 2 <split [5117/569]> Fold02 <tibble [20 × 5]> <tibble [0 × 3]>
## 3 <split [5117/569]> Fold03 <tibble [20 × 5]> <tibble [0 × 3]>
## 4 <split [5117/569]> Fold04 <tibble [20 × 5]> <tibble [0 × 3]>
## 5 <split [5117/569]> Fold05 <tibble [20 × 5]> <tibble [0 × 3]>
## 6 <split [5118/568]> Fold06 <tibble [20 × 5]> <tibble [0 × 3]>
## 7 <split [5118/568]> Fold07 <tibble [20 × 5]> <tibble [0 × 3]>
## 8 <split [5118/568]> Fold08 <tibble [20 × 5]> <tibble [0 × 3]>
## 9 <split [5118/568]> Fold09 <tibble [20 × 5]> <tibble [0 × 3]>
## 10 <split [5118/568]> Fold10 <tibble [20 × 5]> <tibble [0 × 3]>
#logistic Regression
mushroomslog_fit <- fit(mushroomslog_wkflow, mushrooms_train)
predict(mushroomslog_fit,new_data=mushrooms_train,type="prob")
## # A tibble: 5,686 × 2
## .pred_edible .pred_poisonous
## <dbl> <dbl>
## 1 1.00 4.53e-12
## 2 1.00 3.63e-12
## 3 1.00 3.82e-12
## 4 1.00 6.17e-12
## 5 1.00 1.22e-12
## 6 1.00 8.65e-13
## 7 1.00 8.02e-13
## 8 1.00 8.11e-13
## 9 1.00 1.57e-12
## 10 1.00 3.72e-12
## # … with 5,676 more rows
mushrooms_log_kfold_fit <- fit_resamples(mushroomslog_wkflow,mushrooms_folds)
mushrooms_log_kfold_fit
## # Resampling results
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [5116/570]> Fold01 <tibble [2 × 4]> <tibble [2 × 3]>
## 2 <split [5117/569]> Fold02 <tibble [2 × 4]> <tibble [2 × 3]>
## 3 <split [5117/569]> Fold03 <tibble [2 × 4]> <tibble [2 × 3]>
## 4 <split [5117/569]> Fold04 <tibble [2 × 4]> <tibble [2 × 3]>
## 5 <split [5117/569]> Fold05 <tibble [2 × 4]> <tibble [2 × 3]>
## 6 <split [5118/568]> Fold06 <tibble [2 × 4]> <tibble [2 × 3]>
## 7 <split [5118/568]> Fold07 <tibble [2 × 4]> <tibble [2 × 3]>
## 8 <split [5118/568]> Fold08 <tibble [2 × 4]> <tibble [2 × 3]>
## 9 <split [5118/568]> Fold09 <tibble [2 × 4]> <tibble [2 × 3]>
## 10 <split [5118/568]> Fold10 <tibble [2 × 4]> <tibble [2 × 3]>
##
## There were issues with some computations:
##
## - Warning(s) x10: glm.fit: algorithm did not converge, glm.fit: fitted probabilitie... - Warning(s) x10: glm.fit: algorithm did not converge, glm.fit: fitted probabilitie...
##
## Run `show_notes(.Last.tune.result)` for more information.
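The glm.fit warnings reported for the logistic regression resamples are the ones glm typically emits when the classes can be separated perfectly by the predictors. Given the near-perfect accuracy, that is a likely explanation here, although it was not checked in this analysis. The toy example below, with made-up x and y, reproduces the situation in base R.

```r
# Toy data with complete separation: x <= 4 is class 0, x >= 5 is class 1.
x <- 1:8
y <- c(0, 0, 0, 0, 1, 1, 1, 1)

# Collect warning messages instead of printing them.
msgs <- character(0)
fit <- withCallingHandlers(
  glm(y ~ x, family = binomial),
  warning = function(w) {
    msgs <<- c(msgs, conditionMessage(w))
    invokeRestart("muffleWarning")
  }
)

print(msgs)

# The fitted probabilities are pushed to (numerically) 0 or 1.
print(round(fitted(fit), 3))
```

In this situation the predicted classes can still be correct, but the coefficient estimates grow without bound and should not be interpreted.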
#Random Forest
rf_tune_res <- tune_grid(
rf_class_wf,
resamples = mushrooms_folds,
grid = rf_grid)
rf_tune_res
## # Tuning results
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [5116/570]> Fold01 <tibble [250 × 7]> <tibble [0 × 3]>
## 2 <split [5117/569]> Fold02 <tibble [250 × 7]> <tibble [0 × 3]>
## 3 <split [5117/569]> Fold03 <tibble [250 × 7]> <tibble [0 × 3]>
## 4 <split [5117/569]> Fold04 <tibble [250 × 7]> <tibble [0 × 3]>
## 5 <split [5117/569]> Fold05 <tibble [250 × 7]> <tibble [0 × 3]>
## 6 <split [5118/568]> Fold06 <tibble [250 × 7]> <tibble [0 × 3]>
## 7 <split [5118/568]> Fold07 <tibble [250 × 7]> <tibble [0 × 3]>
## 8 <split [5118/568]> Fold08 <tibble [250 × 7]> <tibble [0 × 3]>
## 9 <split [5118/568]> Fold09 <tibble [250 × 7]> <tibble [0 × 3]>
## 10 <split [5118/568]> Fold10 <tibble [250 × 7]> <tibble [0 × 3]>
# decision tree
dt_tune_res <- tune_grid(
mushrooms_tree_wf,
resamples = mushrooms_folds,
grid = dt_grid,
metrics = metric_set(yardstick::roc_auc))
dt_tune_res
## # Tuning results
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [5116/570]> Fold01 <tibble [3 × 5]> <tibble [0 × 3]>
## 2 <split [5117/569]> Fold02 <tibble [3 × 5]> <tibble [0 × 3]>
## 3 <split [5117/569]> Fold03 <tibble [3 × 5]> <tibble [0 × 3]>
## 4 <split [5117/569]> Fold04 <tibble [3 × 5]> <tibble [0 × 3]>
## 5 <split [5117/569]> Fold05 <tibble [3 × 5]> <tibble [0 × 3]>
## 6 <split [5118/568]> Fold06 <tibble [3 × 5]> <tibble [0 × 3]>
## 7 <split [5118/568]> Fold07 <tibble [3 × 5]> <tibble [0 × 3]>
## 8 <split [5118/568]> Fold08 <tibble [3 × 5]> <tibble [0 × 3]>
## 9 <split [5118/568]> Fold09 <tibble [3 × 5]> <tibble [0 × 3]>
## 10 <split [5118/568]> Fold10 <tibble [3 × 5]> <tibble [0 × 3]>
dt_tune_res2 <- tune_grid(
mushrooms_tree_wf,
resamples = mushrooms_folds,
grid = dt_grid,
metrics = metric_set(yardstick::accuracy))
dt_tune_res2
## # Tuning results
## # 10-fold cross-validation using stratification
## # A tibble: 10 × 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [5116/570]> Fold01 <tibble [3 × 5]> <tibble [0 × 3]>
## 2 <split [5117/569]> Fold02 <tibble [3 × 5]> <tibble [0 × 3]>
## 3 <split [5117/569]> Fold03 <tibble [3 × 5]> <tibble [0 × 3]>
## 4 <split [5117/569]> Fold04 <tibble [3 × 5]> <tibble [0 × 3]>
## 5 <split [5117/569]> Fold05 <tibble [3 × 5]> <tibble [0 × 3]>
## 6 <split [5118/568]> Fold06 <tibble [3 × 5]> <tibble [0 × 3]>
## 7 <split [5118/568]> Fold07 <tibble [3 × 5]> <tibble [0 × 3]>
## 8 <split [5118/568]> Fold08 <tibble [3 × 5]> <tibble [0 × 3]>
## 9 <split [5118/568]> Fold09 <tibble [3 × 5]> <tibble [0 × 3]>
## 10 <split [5118/568]> Fold10 <tibble [3 × 5]> <tibble [0 × 3]>
Step 5: Save the tuned models to .rda files to avoid rerunning the models.
# KNN Model
save(knn_tune, file = "knn_tune.rda")
# Logistic Regression
save(mushrooms_log_kfold_fit, file = "log_fit.rda")
# Random Forest
save(rf_tune_res, file = "rf_tune_res.rda")
# Decision Tree
save(dt_tune_res, file = "dt_tune_res.rda")
save(dt_tune_res2, file = "dt_tune_res2.rda")
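To reuse the saved results in a later session, load() restores each object under the name it was saved with. The sketch below shows the round trip with a small stand-in object; results and the temporary file name are hypothetical.

```r
# Hypothetical stand-in for a tuning result.
results <- list(model = "knn", best_neighbors = 3)

# Save to a temporary .rda file, then remove the object from the session.
path <- file.path(tempdir(), "results_sketch.rda")
save(results, file = path)
rm(results)

# load() recreates the object and returns the names it restored.
loaded_names <- load(path)

print(loaded_names)
print(results$best_neighbors)
```

An alternative is saveRDS() with readRDS(), which stores a single object and lets you choose its name when reading it back.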
It’s finally time to compare the results of all of our models and see which ones performed the best!
To summarize the best ROC AUC values from our four models, we will use collect_metrics() to print the mean and standard error of each performance metric across folds, and show_best() to list the top configurations by ROC AUC. Then we can find out which model performed the best.
#knn
collect_metrics(knn_tune)
## # A tibble: 20 × 7
## neighbors .metric .estimator mean n std_err .config
## <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 accuracy binary 1.00 10 0.000176 Preprocessor1_Model01
## 2 1 roc_auc binary 1.00 10 0.000182 Preprocessor1_Model01
## 3 2 accuracy binary 1.00 10 0.000176 Preprocessor1_Model02
## 4 2 roc_auc binary 1.00 10 0.000182 Preprocessor1_Model02
## 5 3 accuracy binary 1.00 10 0.000176 Preprocessor1_Model03
## 6 3 roc_auc binary 1 10 0 Preprocessor1_Model03
## 7 4 accuracy binary 1.00 10 0.000176 Preprocessor1_Model04
## 8 4 roc_auc binary 1 10 0 Preprocessor1_Model04
## 9 5 accuracy binary 1.00 10 0.000176 Preprocessor1_Model05
## 10 5 roc_auc binary 1 10 0 Preprocessor1_Model05
## 11 6 accuracy binary 1.00 10 0.000235 Preprocessor1_Model06
## 12 6 roc_auc binary 1 10 0 Preprocessor1_Model06
## 13 7 accuracy binary 1.00 10 0.000235 Preprocessor1_Model07
## 14 7 roc_auc binary 1 10 0 Preprocessor1_Model07
## 15 8 accuracy binary 1.00 10 0.000235 Preprocessor1_Model08
## 16 8 roc_auc binary 1 10 0 Preprocessor1_Model08
## 17 9 accuracy binary 1.00 10 0.000235 Preprocessor1_Model09
## 18 9 roc_auc binary 1 10 0 Preprocessor1_Model09
## 19 10 accuracy binary 1.00 10 0.000235 Preprocessor1_Model10
## 20 10 roc_auc binary 1 10 0 Preprocessor1_Model10
show_best(knn_tune, metric = "roc_auc")
## # A tibble: 5 × 7
## neighbors .metric .estimator mean n std_err .config
## <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 3 roc_auc binary 1 10 0 Preprocessor1_Model03
## 2 4 roc_auc binary 1 10 0 Preprocessor1_Model04
## 3 5 roc_auc binary 1 10 0 Preprocessor1_Model05
## 4 6 roc_auc binary 1 10 0 Preprocessor1_Model06
## 5 7 roc_auc binary 1 10 0 Preprocessor1_Model07
#logistic
collect_metrics(mushrooms_log_kfold_fit)
## # A tibble: 2 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 1.00 10 0.000176 Preprocessor1_Model1
## 2 roc_auc binary 1 10 0 Preprocessor1_Model1
show_best(mushrooms_log_kfold_fit, metric = "roc_auc")
## # A tibble: 1 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 roc_auc binary 1 10 0 Preprocessor1_Model1
#random forest
collect_metrics(rf_tune_res)
## # A tibble: 250 × 9
## mtry trees min_n .metric .estimator mean n std_err .config
## <int> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 200 10 accuracy binary 0.929 10 0.00595 Preprocessor1_Mod…
## 2 1 200 10 roc_auc binary 0.998 10 0.000311 Preprocessor1_Mod…
## 3 3 200 10 accuracy binary 1.00 10 0.000235 Preprocessor1_Mod…
## 4 3 200 10 roc_auc binary 1 10 0 Preprocessor1_Mod…
## 5 5 200 10 accuracy binary 1.00 10 0.000235 Preprocessor1_Mod…
## 6 5 200 10 roc_auc binary 1 10 0 Preprocessor1_Mod…
## 7 7 200 10 accuracy binary 1.00 10 0.000235 Preprocessor1_Mod…
## 8 7 200 10 roc_auc binary 1 10 0 Preprocessor1_Mod…
## 9 10 200 10 accuracy binary 1.00 10 0.000235 Preprocessor1_Mod…
## 10 10 200 10 roc_auc binary 1 10 0 Preprocessor1_Mod…
## # … with 240 more rows
show_best(rf_tune_res, metric = "roc_auc")
## # A tibble: 5 × 9
## mtry trees min_n .metric .estimator mean n std_err .config
## <int> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 3 200 10 roc_auc binary 1 10 0 Preprocessor1_Model0…
## 2 5 200 10 roc_auc binary 1 10 0 Preprocessor1_Model0…
## 3 7 200 10 roc_auc binary 1 10 0 Preprocessor1_Model0…
## 4 10 200 10 roc_auc binary 1 10 0 Preprocessor1_Model0…
## 5 3 400 10 roc_auc binary 1 10 0 Preprocessor1_Model0…
#decision tree
collect_metrics(dt_tune_res)
## # A tibble: 3 × 7
## cost_complexity .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.001 roc_auc binary 1.00 10 0.000189 Preprocessor1_Model1
## 2 0.01 roc_auc binary 0.978 10 0.00244 Preprocessor1_Model2
## 3 0.1 roc_auc binary 0.946 10 0.00374 Preprocessor1_Model3
show_best(dt_tune_res, metric = "roc_auc")
## # A tibble: 3 × 7
## cost_complexity .metric .estimator mean n std_err .config
## <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.001 roc_auc binary 1.00 10 0.000189 Preprocessor1_Model1
## 2 0.01 roc_auc binary 0.978 10 0.00244 Preprocessor1_Model2
## 3 0.1 roc_auc binary 0.946 10 0.00374 Preprocessor1_Model3
As we can see in the tibbles above, the random forest and the logistic regression both reach a mean cross-validated ROC AUC of 1. The KNN model matches them once it uses 3 or more neighbors, and the best decision tree (cost_complexity = 0.001) is only slightly behind, with a value that rounds to 1.00. Of course, these estimates come from cross validation on the training data, so our models still need to prove themselves on the testing data that we have reserved for exactly this. We will be moving forward with the random forest and logistic regression models. Let’s now visualize the results!
One of the most useful tools for visualizing the results of tuned models is the autoplot() function in R. It shows how changes in the tuning parameters affect our metrics. We will display plots for three of our models, showing both roc_auc and accuracy, since autoplot() cannot be used for our untuned logistic regression model.
#knn
autoplot(knn_tune)
For the KNN model, accuracy starts very high and then drops suddenly as the number of neighbors grows, while the roc_auc values stay consistently high. In general, the KNN model performs well.
#random forest
autoplot(rf_tune_res) + theme_minimal()
For the random forest, we tuned the minimal node size, the number of randomly selected predictors, and the number of trees. From the plots, performance is noticeably lower with a single randomly selected predictor (mtry = 1) and essentially perfect once mtry is 3 or more. Based on the roc_auc values in the plots, this model is among the best performers!
#decision tree
autoplot(dt_tune_res)
autoplot(dt_tune_res2)
For the decision tree, the plots show that both accuracy and roc_auc start out high and steadily decrease as cost_complexity increases. Compared with the other models, the decision tree’s predictions are not as good.
Now that we know the random forest and logistic regression both performed very well, we can move on to analyzing their true results.
Since our logistic regression and random forest had the best overall performance, we now want to examine how strong they are on data they have not seen yet. The high ROC AUC scores above were estimated by cross validation on the training set only, so the testing set gives a final, independent check.
Logistic regression! Because logistic regression has no tuning parameters, there is only one configuration, Preprocessor1_Model1, and it is tied for the best of the four types of model. Below are its cross-validated scores.
show_best(mushrooms_log_kfold_fit, metric = "roc_auc")
## # A tibble: 1 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 roc_auc binary 1 10 0 Preprocessor1_Model1
Random forest! Many random forest configurations tie with a mean ROC AUC of 1, and show_best() lists five of them below. They are also tied for the best of the four types of model. Below are their scores, along with the associated parameters.
show_best(rf_tune_res, metric = "roc_auc")
## # A tibble: 5 × 9
## mtry trees min_n .metric .estimator mean n std_err .config
## <int> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 3 200 10 roc_auc binary 1 10 0 Preprocessor1_Model0…
## 2 5 200 10 roc_auc binary 1 10 0 Preprocessor1_Model0…
## 3 7 200 10 roc_auc binary 1 10 0 Preprocessor1_Model0…
## 4 10 200 10 roc_auc binary 1 10 0 Preprocessor1_Model0…
## 5 3 400 10 roc_auc binary 1 10 0 Preprocessor1_Model0…
Now that we have our best overall models, we can finally apply them to our testing data and discover their actual performance in predicting the edibility of mushrooms.
First, we will take the best models from the tuned random forest and the logistic regression and fit them to the training data. This trains the random forest and the logistic regression one more time on the entire training data set. Once both models have been fit on the training data, they will be ready for testing!
#Logistic Regression
final_log_fit <- fit(mushroomslog_wkflow, mushrooms_train)
#Random Forest
best_rf_train <- select_best(rf_tune_res, metric = 'roc_auc')
rf_final_workflow_train <- finalize_workflow(rf_class_wf, best_rf_train)
rf_final_fit_train <- fit(rf_final_workflow_train, data = mushrooms_train)
Now, it's time to test our random forest and logistic regression models to see how they perform on data that they have not been trained on at all: the testing data set.
# Logistic Regression
# Use the model refit on the full training set and predict on the held-out testing data
augment(final_log_fit, new_data = mushrooms_test) %>%
  conf_mat(truth = class, estimate = .pred_class) %>%
  autoplot(type = "heatmap")
# Random Forest
rf_final_fit_train %>% extract_fit_parsnip()
## parsnip model object
##
## Ranger result
##
## Call:
## ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~3L, x), num.trees = ~200L, min.node.size = min_rows(~10L, x), num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1), probability = TRUE)
##
## Type: Probability estimation
## Number of trees: 200
## Sample size: 5686
## Number of independent variables: 69
## Mtry: 3
## Target node size: 10
## Variable importance mode: none
## Splitrule: gini
## OOB prediction error (Brier s.): 0.009642352
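# Predict on the testing data; keep the true class and the prediction columns
final_rf_model_test <- augment(rf_final_fit_train,
                               mushrooms_test) %>%
  select(class, starts_with(".pred"))
# Predicted class probabilities only
roc_data <- final_rf_model_test %>%
  dplyr::select(starts_with(".pred_"))
# Confusion matrix of true vs. predicted class on the testing data
conf_mat(final_rf_model_test, truth = class, .pred_class) %>%
  autoplot(type = "heatmap")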
Wow! It's so delightful! From the plots above of the testing data results, both the logistic regression model and the random forest model did a perfect job of predicting the edibility of mushrooms! Both models can correctly predict whether a mushroom is edible or not based on its characteristics. I think this is a very good result, which shows that our previous efforts were not in vain and the models were very successful!
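The heatmaps can also be summarized as numbers. Below is a minimal base R sketch, not the project's own code, that computes accuracy and ROC AUC from a table shaped like the output of `augment()`. The six toy rows, the "e"/"p" class labels, and the `.pred_p` column name are assumptions for illustration; the real column names depend on how `class` is coded in the data.

```r
# Toy stand-in for augment() output: the true class and the predicted
# probability of the "p" (poisonous) class. Values are made up for illustration.
preds <- data.frame(
  class = factor(c("e", "e", "e", "p", "p", "p"), levels = c("e", "p")),
  .pred_p = c(0.05, 0.20, 0.60, 0.40, 0.90, 0.95)
)

# Hard class prediction using a 0.5 probability threshold
preds$.pred_class <- factor(
  ifelse(preds$.pred_p >= 0.5, "p", "e"),
  levels = c("e", "p")
)

# Confusion matrix: rows are the truth, columns are the prediction
conf <- table(truth = preds$class, estimate = preds$.pred_class)

# Accuracy: share of predictions on the diagonal (here 4 of 6)
accuracy <- sum(diag(conf)) / sum(conf)

# Poisonous mushrooms predicted as edible: the most dangerous kind of error
missed_poisonous <- conf["p", "e"]

# ROC AUC via the rank (Mann-Whitney) formulation: the probability that a
# randomly chosen poisonous mushroom gets a higher score than an edible one
r <- rank(preds$.pred_p)
n_pos <- sum(preds$class == "p")
n_neg <- sum(preds$class == "e")
auc <- (sum(r[preds$class == "p"]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

print(conf)
print(accuracy)  # 4/6
print(auc)       # 8/9
```

For a perfect classifier the off-diagonal cells of `conf` are all zero and both accuracy and ROC AUC equal 1.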
Throughout the project, we studied, explored and analyzed our data and its variables in order to build and test a model that could predict the edibility of different mushrooms. After continuous analysis, testing and calculations, we can say that the random forest model and the logistic regression model are the best at predicting the edibility of mushrooms. Not only are they the best, but they are also very accurate.
For the future, I think we could collect data on different mushrooms from photos that people upload online, and then use the model we have created to analyze whether each mushroom is edible or not. This could do a lot for human health in the future. I believe that in practice, this could be combined with a model such as the one in this project. The higher the accuracy and ROC AUC score, the more accurate the prediction of the edibility of the mushrooms.
On the available data, our model performance is already excellent. In reality, however, although our dataset has a very comprehensive characterization of different mushrooms, it still can't cover all the mushrooms in the world. So when we really live or explore in the wild, it is still risky to try different wild mushrooms. Although our model predicts very well on the available data, it is still not completely safe to eat wild mushrooms, and we still need to pay attention to the health risks in this area. After all, we cannot predict how severe the toxicity of a mushroom is either. The fact that I can accurately predict the edibility of mushrooms in the available data is a huge achievement and something I am proud of. It was a great opportunity for me to gain experience and skills in machine learning techniques and to explore a subject that is very interesting and meaningful to me.
Overall, the biggest thing I got out of this project was a new appreciation for mushrooms. I think I will slowly change my aversion to mushrooms and try to experience them. Choosing a topic that was outside of my knowledge and interests also allowed me to push myself and be creative!